{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "# Training and Evaluating a Text Classification Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction \n",
    "\n",
    "In this notebook, we demonstrate how to solve a text classification task with a word2vec + logistic regression model on Spark.\n",
    "\n",
    "The sample dataset consists of metadata for books digitised by the British Library in partnership with Microsoft. It includes human-generated labels for whether a book is 'fiction' or 'non-fiction'. We use this dataset to train a genre classification model that predicts whether a book is 'fiction' or 'non-fiction' based on its title.\n",
    "\n",
    "|BL record ID|Type of resource|Name|Dates associated with name|Type of name|Role|All names|Title|Variant titles|Series title|Number within series|Country of publication|Place of publication|Publisher|Date of publication|Edition|Physical description|Dewey classification|BL shelfmark|Topics|Genre|Languages|Notes|BL record ID for physical resource|classification_id|user_id|created_at|subject_ids|annotator_date_pub|annotator_normalised_date_pub|annotator_edition_statement|annotator_genre|annotator_FAST_genre_terms|annotator_FAST_subject_terms|annotator_comments|annotator_main_language|annotator_other_languages_summaries|annotator_summaries_language|annotator_translation|annotator_original_language|annotator_publisher|annotator_place_pub|annotator_country|annotator_title|Link to digitised book|annotated|\n",
    "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
    "|014602826|Monograph|Yearsley, Ann|1753-1806|person||More, Hannah, 1745-1833 [person] ; Yearsley, Ann, 1753-1806 [person]|Poems on several occasions [With a prefatory letter by Hannah More.]||||England|London||1786|Fourth edition MANUSCRIPT note|||Digital Store 11644.d.32|||English||003996603||||||||||||||||||||||False|\n",
    "|014602830|Monograph|A, T.||person||Oldham, John, 1653-1683 [person] ; A, T. [person]|A Satyr against Vertue. (A poem: supposed to be spoken by a Town-Hector [By John Oldham. The preface signed: T. A.])||||England|London||1679||15 pages (4°)||Digital Store 11602.ee.10. (2.)|||English||000001143||||||||||||||||||||||False|\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 1: Load the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Notebook Configurations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Install libraries\n",
    "This notebook uses `wordcloud`, which must be installed first. Because the PySpark kernel restarts after `%pip install`, run this installation before any other cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install wordcloud for text visualization\n",
    "%pip install wordcloud"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**By defining the parameters below, we can easily apply this notebook to different datasets.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"33\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "IS_CUSTOM_DATA = False  # if True, dataset has to be uploaded manually by user\n",
    "DATA_FOLDER = \"Files/title-genre-classification\"\n",
    "DATA_FILE = \"blbooksgenre.csv\"\n",
    "\n",
    "# data schema\n",
    "TEXT_COL = \"Title\"\n",
    "LABEL_COL = \"annotator_genre\"\n",
    "LABELS = [\"Fiction\", \"Non-fiction\"]\n",
    "\n",
    "EXPERIMENT_NAME = \"aisample-textclassification\"  # mlflow experiment name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "We also define some hyperparameters for model training (*don't modify these parameters unless you understand what they do*)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"34\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# hyper-params\n",
    "word2vec_size = 128\n",
    "min_word_count = 3\n",
    "max_iter = 10\n",
    "k_folds = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from itertools import chain\n",
    "\n",
    "from wordcloud import WordCloud\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "import pyspark.sql.functions as F\n",
    "\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.feature import StopWordsRemover, StringIndexer, Tokenizer, Word2Vec\n",
    "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml.evaluation import (\n",
    "    BinaryClassificationEvaluator,\n",
    "    MulticlassClassificationEvaluator,\n",
    ")\n",
    "\n",
    "from synapse.ml.stages import ClassBalancer\n",
    "from synapse.ml.train import ComputeModelStatistics\n",
    "\n",
    "import mlflow\n",
    "import trident.mlflow\n",
    "from trident.mlflow import get_sds_url"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Download dataset and upload to lakehouse\n",
    "\n",
    "**Please add a lakehouse to the notebook before running it.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"35\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "if not IS_CUSTOM_DATA:\n",
    "    # Download the demo data file into the lakehouse if it does not already exist\n",
    "    import os, requests\n",
    "\n",
    "    remote_url = \"https://synapseaisolutionsa.blob.core.windows.net/public/Title_Genre_Classification\"\n",
    "    fname = \"blbooksgenre.csv\"\n",
    "    download_path = f\"/lakehouse/default/{DATA_FOLDER}/raw\"\n",
    "\n",
    "    if not os.path.exists(\"/lakehouse/default\"):\n",
    "        raise FileNotFoundError(\"Default lakehouse not found, please add a lakehouse.\")\n",
    "    os.makedirs(download_path, exist_ok=True)\n",
    "    if not os.path.exists(f\"{download_path}/{fname}\"):\n",
    "        r = requests.get(f\"{remote_url}/{fname}\", timeout=30)\n",
    "        r.raise_for_status()  # fail early instead of saving an error page as the CSV\n",
    "        with open(f\"{download_path}/{fname}\", \"wb\") as f:\n",
    "            f.write(r.content)\n",
    "    print(\"Downloaded demo data files into lakehouse.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Read data from lakehouse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"36\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "raw_df = spark.read.csv(f\"{DATA_FOLDER}/raw/{DATA_FILE}\", header=True, inferSchema=True)\n",
    "\n",
    "display(raw_df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 2: Preprocess Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Clean data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"37\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "df = (\n",
    "    raw_df.select([TEXT_COL, LABEL_COL])\n",
    "    .where(F.col(LABEL_COL).isin(LABELS))\n",
    "    .dropDuplicates([TEXT_COL])\n",
    "    .cache()\n",
    ")\n",
    "\n",
    "display(df.limit(20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Deal with unbalanced data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"38\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "cb = ClassBalancer().setInputCol(LABEL_COL)\n",
    "\n",
    "df = cb.fit(df).transform(df)\n",
    "display(df.limit(20))"
   ]
  },
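  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ClassBalancer` adds a `weight` column so that rarer classes count more during training. As far as we understand its behaviour, each class receives the weight (count of the largest class) / (count of that class). The next cell reproduces that arithmetic in plain Python on made-up labels. The names `toy_labels`, `toy_counts`, and `toy_weights` are hypothetical and used only for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Made-up labels for illustration; not taken from the dataset.\n",
    "toy_labels = [\"Non-fiction\"] * 6 + [\"Fiction\"] * 2\n",
    "\n",
    "toy_counts = Counter(toy_labels)\n",
    "toy_max_count = max(toy_counts.values())\n",
    "\n",
    "# Assumed rule: weight = size of the largest class / size of this class.\n",
    "toy_weights = {name: toy_max_count / count for name, count in toy_counts.items()}\n",
    "print(toy_weights)"
   ]
  },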
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"39\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "## text transformer\n",
    "tokenizer = Tokenizer(inputCol=TEXT_COL, outputCol=\"tokens\")\n",
    "stopwords_remover = StopWordsRemover(inputCol=\"tokens\", outputCol=\"filtered_tokens\")\n",
    "\n",
    "## build the pipeline\n",
    "pipeline = Pipeline(stages=[tokenizer, stopwords_remover])\n",
    "\n",
    "token_df = pipeline.fit(df).transform(df)\n",
    "\n",
    "display(token_df.limit(20))"
   ]
  },
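  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the two stages above, the next cell mimics them in plain Python: `Tokenizer` lowercases the text and splits it on whitespace, and `StopWordsRemover` drops common words. The function name `toy_tokenize` and the tiny stop-word set are assumptions made for this sketch; Spark's default English stop-word list is much longer, and its whitespace handling may differ in edge cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-Python sketch of the Tokenizer + StopWordsRemover stages (illustration only).\n",
    "# TOY_STOP_WORDS is a tiny hypothetical subset, not Spark's default list.\n",
    "TOY_STOP_WORDS = {\"a\", \"an\", \"and\", \"of\", \"on\", \"the\"}\n",
    "\n",
    "\n",
    "def toy_tokenize(text):\n",
    "    # Lowercase the text, split on whitespace, then drop stop words.\n",
    "    words = text.lower().split()\n",
    "    return [w for w in words if w not in TOY_STOP_WORDS]\n",
    "\n",
    "\n",
    "print(toy_tokenize(\"Poems on Several Occasions\"))"
   ]
  },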
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Visualization\n",
    "\n",
    "Display a word cloud for each class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"40\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# WordCloud\n",
    "for label in LABELS:\n",
    "    tokens = (\n",
    "        token_df.where(F.col(LABEL_COL) == label)\n",
    "        .select(F.explode(\"filtered_tokens\").alias(\"token\"))\n",
    "        .where(F.col(\"token\").rlike(r\"^\\w+$\"))\n",
    "    )\n",
    "\n",
    "    top50_tokens = (\n",
    "        tokens.groupBy(\"token\").count().orderBy(F.desc(\"count\")).limit(50).collect()\n",
    "    )\n",
    "\n",
    "    # Generate a word cloud image\n",
    "    wordcloud = WordCloud(\n",
    "        scale=10,\n",
    "        background_color=\"white\",\n",
    "        random_state=42,  # Make sure the output is always the same for the same input\n",
    "    ).generate_from_frequencies(dict(top50_tokens))\n",
    "\n",
    "    # Display the generated image the matplotlib way:\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.title(label, fontsize=20)\n",
    "    plt.axis(\"off\")\n",
    "    plt.imshow(wordcloud, interpolation=\"bilinear\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Vectorize\n",
    "\n",
    "We use word2vec to turn each title's tokens into a fixed-length feature vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"41\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "## label transformer\n",
    "label_indexer = StringIndexer(inputCol=LABEL_COL, outputCol=\"labelIdx\")\n",
    "vectorizer = Word2Vec(\n",
    "    vectorSize=word2vec_size,\n",
    "    minCount=min_word_count,\n",
    "    inputCol=\"filtered_tokens\",\n",
    "    outputCol=\"features\",\n",
    ")\n",
    "\n",
    "## build the pipeline\n",
    "pipeline = Pipeline(stages=[label_indexer, vectorizer])\n",
    "vec_df = (\n",
    "    pipeline.fit(token_df)\n",
    "    .transform(token_df)\n",
    "    .select([TEXT_COL, LABEL_COL, \"features\", \"labelIdx\", \"weight\"])\n",
    ")\n",
    "\n",
    "display(vec_df.limit(20))"
   ]
  },
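  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Spark's `Word2VecModel` turns each document into a single vector by averaging the vectors of its words. The next cell illustrates that averaging with NumPy. The three 4-dimensional vectors are made up for this sketch, whereas the real model learns `word2vec_size`-dimensional vectors from the titles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical word vectors; a trained Word2Vec model would supply these.\n",
    "toy_word_vectors = {\n",
    "    \"poems\": np.array([0.2, -0.1, 0.4, 0.0]),\n",
    "    \"several\": np.array([0.0, 0.3, -0.2, 0.1]),\n",
    "    \"occasions\": np.array([0.1, 0.1, 0.1, 0.2]),\n",
    "}\n",
    "\n",
    "toy_title_tokens = [\"poems\", \"several\", \"occasions\"]\n",
    "\n",
    "# The document vector is the element-wise mean of its word vectors.\n",
    "toy_doc_vector = np.mean([toy_word_vectors[t] for t in toy_title_tokens], axis=0)\n",
    "print(toy_doc_vector)"
   ]
  },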
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 3: Model Training and Evaluation\n",
    "\n",
    "We have cleaned the dataset, dealt with the unbalanced data, tokenized the text, displayed word clouds, and vectorized the text.\n",
    "\n",
    "Next, we'll train a logistic regression model to classify the vectorized text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Split dataset into train and test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"42\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "(train_df, test_df) = vec_df.randomSplit((0.8, 0.2), seed=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"43\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "lr = (\n",
    "    LogisticRegression()\n",
    "    .setMaxIter(max_iter)\n",
    "    .setFeaturesCol(\"features\")\n",
    "    .setLabelCol(\"labelIdx\")\n",
    "    .setWeightCol(\"weight\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Train model with cross validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"44\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "param_grid = (\n",
    "    ParamGridBuilder()\n",
    "    .addGrid(lr.regParam, [0.03, 0.1, 0.3])\n",
    "    .addGrid(lr.elasticNetParam, [0.0, 0.1, 0.2])\n",
    "    .build()\n",
    ")\n",
    "\n",
    "if len(LABELS) > 2:\n",
    "    evaluator_cls = MulticlassClassificationEvaluator\n",
    "    evaluator_metrics = [\"f1\", \"accuracy\"]\n",
    "else:\n",
    "    evaluator_cls = BinaryClassificationEvaluator\n",
    "    evaluator_metrics = [\"areaUnderROC\", \"areaUnderPR\"]\n",
    "evaluator = evaluator_cls(labelCol=\"labelIdx\", weightCol=\"weight\")\n",
    "\n",
    "crossval = CrossValidator(\n",
    "    estimator=lr, estimatorParamMaps=param_grid, evaluator=evaluator, numFolds=k_folds\n",
    ")\n",
    "\n",
    "model = crossval.fit(train_df)"
   ]
  },
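  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cross-validation fits one model per parameter combination per fold, and then refits the best combination on the full training set. The next cell counts those fits for the grid above. The names prefixed with `toy_` are introduced only for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "# Same candidate values as the parameter grid above.\n",
    "toy_reg_params = [0.03, 0.1, 0.3]\n",
    "toy_elastic_net_params = [0.0, 0.1, 0.2]\n",
    "toy_folds = 3  # mirrors k_folds\n",
    "\n",
    "# Every pairing of the two parameters is one candidate model configuration.\n",
    "toy_grid = list(product(toy_reg_params, toy_elastic_net_params))\n",
    "toy_cv_fits = len(toy_grid) * toy_folds\n",
    "\n",
    "print(f\"{len(toy_grid)} combinations x {toy_folds} folds = {toy_cv_fits} fits, plus 1 final refit\")"
   ]
  },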
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"46\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "predictions = model.transform(test_df)\n",
    "\n",
    "display(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"47\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "log_metrics = {}\n",
    "for metric in evaluator_metrics:\n",
    "    value = evaluator.evaluate(predictions, {evaluator.metricName: metric})\n",
    "    log_metrics[metric] = value\n",
    "    print(f\"{metric}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"48\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "metrics = ComputeModelStatistics(\n",
    "    evaluationMetric=\"classification\", labelCol=\"labelIdx\", scoredLabelsCol=\"prediction\"\n",
    ").transform(predictions)\n",
    "display(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "advisor": {
     "adviceMetadata": "{\"artifactId\":\"2869283e-cc68-4bec-a46e-95bf731d931f\",\"activityId\":\"aaa495e9-53ed-4e64-85eb-900ab7d8eff4\",\"applicationId\":\"application_1657773129532_0001\",\"jobGroupId\":\"49\",\"advices\":{}}"
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# collect confusion matrix value\n",
    "cm = metrics.select(\"confusion_matrix\").collect()[0][0].toArray()\n",
    "print(cm)\n",
    "\n",
    "# plot confusion matrix\n",
    "sns.set(rc={\"figure.figsize\": (6, 4.5)})\n",
    "ax = sns.heatmap(cm, annot=True, fmt=\".20g\")\n",
    "ax.set_title(\"Confusion Matrix\")\n",
    "ax.set_xlabel(\"Predicted label\")\n",
    "ax.set_ylabel(\"True label\")"
   ]
  },
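  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The confusion matrix summarises every prediction, so the usual metrics can be read off it. The next cell shows the arithmetic on a made-up 2x2 matrix. It assumes rows are true labels and columns are predicted labels (matching the axis labels in the plot above), and it treats index 1 as the positive class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up counts for illustration; rows = true label, columns = predicted label.\n",
    "toy_cm = np.array([[50, 10], [5, 35]])\n",
    "\n",
    "toy_tn, toy_fp = toy_cm[0]\n",
    "toy_fn, toy_tp = toy_cm[1]\n",
    "\n",
    "toy_accuracy = (toy_tp + toy_tn) / toy_cm.sum()\n",
    "toy_precision = toy_tp / (toy_tp + toy_fp)\n",
    "toy_recall = toy_tp / (toy_tp + toy_fn)\n",
    "\n",
    "print(f\"accuracy={toy_accuracy:.3f} precision={toy_precision:.3f} recall={toy_recall:.3f}\")"
   ]
  },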
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log and Load Model with MLflow\n",
    "Now that we have a trained model, we can save it for later use. Here we use MLflow to log the metrics and the model, and then load the model back for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup mlflow\n",
    "mlflow.set_tracking_uri(get_sds_url())\n",
    "mlflow.set_registry_uri(get_sds_url())\n",
    "mlflow.set_experiment(EXPERIMENT_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# log model, metrics and params\n",
    "with mlflow.start_run() as run:\n",
    "    print(\"log model:\")\n",
    "    mlflow.spark.log_model(\n",
    "        model,\n",
    "        f\"{EXPERIMENT_NAME}-lrmodel\",\n",
    "        registered_model_name=f\"{EXPERIMENT_NAME}-lrmodel\",\n",
    "        dfs_tmpdir=\"Files/spark\",\n",
    "    )\n",
    "\n",
    "    print(\"log metrics:\")\n",
    "    mlflow.log_metrics(log_metrics)\n",
    "\n",
    "    print(\"log parameters:\")\n",
    "    mlflow.log_params(\n",
    "        {\n",
    "            \"word2vec_size\": word2vec_size,\n",
    "            \"min_word_count\": min_word_count,\n",
    "            \"max_iter\": max_iter,\n",
    "            \"k_folds\": k_folds,\n",
    "            \"DATA_FILE\": DATA_FILE,\n",
    "        }\n",
    "    )\n",
    "\n",
    "    model_uri = f\"runs:/{run.info.run_id}/{EXPERIMENT_NAME}-lrmodel\"\n",
    "    print(f\"Model saved in run {run.info.run_id}\")\n",
    "    print(f\"Model URI: {model_uri}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model back\n",
    "loaded_model = mlflow.spark.load_model(model_uri, dfs_tmpdir=\"Files/spark\")\n",
    "\n",
    "# verify loaded model\n",
    "predictions = loaded_model.transform(test_df)\n",
    "display(predictions)"
   ]
  }
 ],
 "metadata": {
  "kernel_info": {
   "name": "synapse_pyspark"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "notebook_environment": {},
  "save_output": true,
  "spark_compute": {
   "compute_id": "/trident/default",
   "session_options": {
    "conf": {
     "spark.dynamicAllocation.enabled": "false",
     "spark.dynamicAllocation.maxExecutors": "2",
     "spark.dynamicAllocation.minExecutors": "2",
     "spark.livy.synapse.ipythonInterpreter.enabled": "true"
    },
    "driverCores": 8,
    "driverMemory": "56g",
    "enableDebugMode": false,
    "executorCores": 8,
    "executorMemory": "56g",
    "keepAliveTimeout": 30,
    "numExecutors": 5
   }
  },
  "synapse_widget": {
   "state": {},
   "version": "0.1"
  },
  "trident": {
   "lakehouse": {
    "default_lakehouse": "",
    "known_lakehouses": [
     {
      "id": ""
     }
    ]
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "8cebba326b76ca708172f0a6a24a89689a3b64f83dbd9353b827f2f4b33d3f80"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
